//
// Created by xyg on 2022/3/18.
//

#ifndef SJWL_NET_HPP
#define SJWL_NET_HPP
#include "sjwl"
#include "layer.hpp"
#include "tensor.hpp"
#include <cstdarg>
#include "model.hpp"


namespace nn {
    // Sequential container of layers: forward() feeds each layer's output into
    // the next one in insertion order, backward() walks the layers in reverse.
    //
    // The class stores raw LayerBase pointers and never deletes them (the
    // destructor does nothing), so the layers must outlive the net and be
    // released by whoever created them.
    class net {
    private:
        int _layer_count;       // number of layers held in _layer_list
        Layerlist _layer_list;  // layers in forward (insertion) order
    public:
        // Builds a net from a C-style variadic list of layer pointers.
        //
        // layers: total number of layer pointers passed, *including* l1.
        // l1:     first layer; when null the net starts empty and the remaining
        //         arguments (if any) are ignored.
        // ...:    the remaining (layers - 1) arguments, each a LayerBase*.
        //
        // Passing fewer variadic arguments than announced by `layers` is
        // undefined behaviour (inherent to va_arg); the caller must keep the
        // count and the argument list in sync.
        net(int layers, LayerBase *l1 = nullptr, ...) : _layer_count(0) {
            if (l1 == nullptr) return;

            _layer_list.push_back(l1);
            _layer_count = 1;

            va_list arg_ptr;
            va_start(arg_ptr, l1);
            // Bounded loop: for layers <= 1 nothing more is read. The previous
            // `while (--layers)` ran away reading nonexistent arguments when
            // layers was 0 or negative.
            for (int i = 1; i < layers; ++i) {
                _layer_list.push_back(va_arg(arg_ptr, LayerBase *));
                ++_layer_count;
            }
            // Every va_start must be paired with va_end before returning.
            va_end(arg_ptr);
        }

        // Layers are not owned by the net, so there is nothing to release.
        ~net() = default;

        // Read-only access to the layers in forward order.
        const Layerlist& get_layerlist() const { return _layer_list; }

        // Runs x through every layer in insertion order and returns the output
        // of the last layer. With no layers the input is returned unchanged
        // (as a copy).
        Dmatrix forward(const Dmatrix &x) {
            Dmatrix out = x;
            for (auto layer : _layer_list) {
                out = layer->forward(out);
            }
            return out;
        }

        // Propagates `grad` through the layers from last to first.
        //
        // Returns a heap-allocated Tensorlist with one entry per layer, in
        // forward order (entry k belongs to the k-th layer added). Each entry
        // is a heap-allocated copy of that layer's get_grad(). The caller owns
        // both the list and every Dtensor in it and must delete them.
        // For an empty net the returned list is empty.
        Tensorlist *backward(const Dtensor &grad) {
            auto *grads = new Tensorlist(this->_layer_count);
            if (this->_layer_count == 0)
                return grads;

            Dtensor upstream = grad;
            auto slot = grads->rbegin();
            for (auto il = _layer_list.rbegin(); il != _layer_list.rend(); ++il, ++slot) {
                // NOTE(review): get_grad() is read *before* backward() runs on
                // the same layer, so the stored tensor reflects the layer's
                // state prior to this call — confirm against LayerBase that
                // this ordering is intended.
                *slot = new Dtensor((*il)->get_grad());
                upstream = (*il)->backward(upstream);
            }
            return grads;
        }

        // Appends a layer at the end of the forward chain. The pointer is
        // stored as-is (not copied, not owned).
        // Declared noexcept: if push_back fails to allocate, the program
        // terminates instead of propagating the exception.
        void add(LayerBase *l) noexcept {
            _layer_list.push_back(l);
            ++_layer_count;
        }
    };

}
#endif //SJWL_NET_HPP
